{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nCompile Keras Models\n=====================\n**Author**: `Yuwei Hu <https://Huyuwei.github.io/>`_\n\nThis article is an introductory tutorial on deploying Keras models with Relay.\n\nTo begin, Keras must be installed.\nTensorFlow is also required since it is used as the default backend of Keras.\n\nA quick solution is to install both via pip:\n\n.. code-block:: bash\n\n    pip install -U keras --user\n    pip install -U tensorflow --user\n\nAlternatively, refer to the official installation guide:\nhttps://keras.io/#installation\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\nfrom tvm import te\nimport tvm.relay as relay\nfrom tvm.contrib.download import download_testdata\nimport keras\nimport numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Load pretrained keras model\n----------------------------\nWe load a pretrained ResNet-50 classification model provided by Keras.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Compare (major, minor) numerically: comparing the raw string parts would\n# sort '10' before '4' and treat Keras 2.10+ as older than 2.4.\nkeras_release = tuple(int(part) for part in keras.__version__.split(\".\")[:2])\nif keras_release < (2, 4):\n    # Older Keras: use the weights file hosted on GitHub.\n    weights_url = \"\".join(\n        [\n            \"https://github.com/fchollet/deep-learning-models/releases/\",\n            \"download/v0.2/resnet50_weights_tf_dim_ordering_tf_kernels.h5\",\n        ]\n    )\n    weights_file = \"resnet50_keras_old.h5\"\nelse:\n    # Keras 2.4 and newer: use the weights file hosted on Google Cloud Storage.\n    # (No leading space in the URL -- it must be a well-formed address.)\n    weights_url = \"\".join(\n        [\n            \"https://storage.googleapis.com/tensorflow/keras-applications/\",\n            \"resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5\",\n        ]\n    )\n    weights_file = \"resnet50_keras_new.h5\"\n\n\n# Download the weights and load them into a ResNet-50 built without\n# preloaded weights (weights=None).\nweights_path = download_testdata(weights_url, weights_file, module=\"keras\")\nkeras_resnet50 = keras.applications.resnet50.ResNet50(\n    include_top=True, weights=None, input_shape=(224, 224, 3), classes=1000\n)\nkeras_resnet50.load_weights(weights_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Load a test image\n------------------\nA single cat dominates the examples!\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from PIL import Image\nfrom matplotlib import pyplot as plt\nfrom keras.applications.resnet50 import preprocess_input\n\n# Fetch the sample picture and scale it to 224x224.\nimg_url = \"https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true\"\nimg_path = download_testdata(img_url, \"cat.png\", module=\"data\")\nimg = Image.open(img_path).resize((224, 224))\n\n# Display the resized image.\nplt.imshow(img)\nplt.show()\n\n# input preprocess: add a leading batch axis, cast to float32, apply the\n# Keras ResNet-50 preprocessing, then reorder the axes with (0, 3, 1, 2).\nimg_batch = np.expand_dims(np.array(img), axis=0).astype(\"float32\")\ndata = np.transpose(preprocess_input(img_batch), (0, 3, 1, 2))\nprint(\"input_1\", data.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Compile the model with Relay\n----------------------------\nConvert the Keras model (NHWC layout) to Relay format (NCHW layout).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Map the model's input name to the shape of the preprocessed batch and\n# import the Keras model into Relay.\nshape_dict = {\"input_1\": data.shape}\nmod, params = relay.frontend.from_keras(keras_resnet50, shape_dict)\n\n# compile the model\ntarget = \"cuda\"\ndev = tvm.cuda(0)\n\n# TODO(mbs): opt_level=3 causes nn.contrib_conv2d_winograd_weight_transform\n# to end up in the module which fails memory validation on cuda most likely\n# due to a latent bug. Note that the pass context only has an effect within\n# evaluate() and is not captured by create_executor().\nwith tvm.transform.PassContext(opt_level=0):\n    graph_executor = relay.build_module.create_executor(\"graph\", mod, dev, target, params)\n    model = graph_executor.evaluate()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Execute on TVM\n---------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dtype = \"float32\"\n# Hand the preprocessed batch to TVM and run the compiled model on it.\ntvm_input = tvm.nd.array(data.astype(dtype))\ntvm_out = model(tvm_input)\n# Index of the highest score in the first row of the output.\ntop1_tvm = np.argmax(tvm_out.numpy()[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Look up synset name\n-------------------\nLook up the top-1 predicted class index in the 1000-class synset.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import ast\n\nsynset_url = \"\".join(\n    [\n        \"https://gist.githubusercontent.com/zhreshold/\",\n        \"4d0b62f3d01426887599d4f7ede23ee5/raw/\",\n        \"596b27d23537e5a1b5751d2b0481ef172f58b539/\",\n        \"imagenet1000_clsid_to_human.txt\",\n    ]\n)\nsynset_name = \"imagenet1000_clsid_to_human.txt\"\nsynset_path = download_testdata(synset_url, synset_name, module=\"data\")\n# The file is expected to hold a Python dict literal (class id -> name).\n# It comes from the network, so parse it with ast.literal_eval instead of\n# eval(): eval() would execute any code embedded in the downloaded text.\nwith open(synset_path) as f:\n    synset = ast.literal_eval(f.read())\nprint(\"Relay top-1 id: {}, class name: {}\".format(top1_tvm, synset[top1_tvm]))\n# confirm correctness with keras output; the Keras model takes the batch with\n# the axes reordered by (0, 2, 3, 1)\nkeras_out = keras_resnet50.predict(data.transpose([0, 2, 3, 1]))\n# Take the argmax over the first row, matching how top1_tvm is computed.\ntop1_keras = np.argmax(keras_out[0])\nprint(\"Keras top-1 id: {}, class name: {}\".format(top1_keras, synset[top1_keras]))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}